from argparse import ArgumentParser


class Parser(ArgumentParser):
    """Command-line parser whose option set depends on the sub-command.

    The base options are always registered. ``command == "dataset"`` adds the
    dataset options and ``command == "train"`` adds the training options; any
    other value leaves only the base options registered.
    """

    def __init__(self, command: str) -> None:
        """Build the parser for ``command``.

        Args:
            command: Selects the extra option group. ``"dataset"`` and
                ``"train"`` are recognised; other values add nothing.
        """
        super().__init__()
        self._add_base_args()

        if command == "dataset":
            self._add_dataset_args()
        elif command == "train":
            self._add_train_args()

    def _add_base_args(self) -> None:
        """Register the options shared by every command."""
        self.add_argument(
            "--env-name",
            type=str,
            default="metaworld.drawer-open-v2",
            help="The name of the environment to collect the dataset",
        )
        self.add_argument(
            "--warmup",
            action="store_true",
            help="Whether to warmup the environment",
        )
        # NOTE(review): the help text used to be a copy of the --warmup help.
        # The exact meaning of the value is not visible in this file, so the
        # wording is kept neutral -- confirm with the consumer of this option.
        self.add_argument(
            "--goal-resistance",
            type=int,
            default=0,
            help="The goal resistance value",
        )
        self.add_argument(
            "--n-episodes",
            type=int,
            default=10,
            help="The number of episodes to collect",
        )
        self.add_argument(
            "--seed",
            type=int,
            default=777,
            help="The seed for the environment",
        )
        self.add_argument(
            "--video",
            action="store_true",
            help="Whether to save video of the environment",
        )
        self.add_argument("--verbose", action="store_true", help="Printing mode on")
        # The help text used to duplicate --verbose, making the two flags
        # indistinguishable in the -h output.
        self.add_argument("--debug", action="store_true", help="Debug mode on")

    def _add_dataset_args(self) -> None:
        """Register the options that only apply to the ``dataset`` command."""
        self.add_argument(
            "--save-path",
            type=str,
            default="datasets",
            help="The path to save the dataset",
        )

    def _add_train_args(self) -> None:
        """Register the options that only apply to the ``train`` command."""
        self.add_argument(
            "--dataset-path",
            type=str,
            default="datasets",
            help="The path to load the dataset",
        )
        self.add_argument(
            "--result-path",
            type=str,
            default="results",
            help="The path to save the result",
        )
        self.add_argument(
            "--algo",
            type=str,
            default="diffbc",
            help="The name of the algorithm to train",
        )
        self.add_argument(
            "--delta",
            type=float,
            default=0.05,
            help="The scale for the guidance gradient",
        )
        # NOTE(review): no help text; the meaning of this value is not
        # visible in this file.
        self.add_argument(
            "--history",
            type=int,
            default=None,
        )
        self.add_argument(
            "--guide",
            type=str,
            default=None,
            help="The name of the guide function for diffusion sampling",
        )
        self.add_argument(
            "--multimodal",
            action="store_true",
            help="Text or Multimodal",
        )
        self.add_argument(
            "--prompt",
            type=str,
            default=None,
            help="The prompt for the guide function",
        )
        # Hyphenated spelling matches every other option; the original
        # underscore spelling stays as an alias so existing command lines
        # keep working. The namespace attribute is ``prompt_idx`` either way.
        self.add_argument(
            "--prompt-idx",
            "--prompt_idx",
            dest="prompt_idx",
            type=int,
            default=None,
            help="The prompt index for the guide function",
        )
        # store_false: ``validate_no`` is True unless the flag is passed.
        self.add_argument(
            "--validate-no",
            action="store_false",
            help="Set validate_no to False (it is True when the flag is absent)",
        )
        self.add_argument(
            "--train",
            action="store_true",
            help="Whether to train the model",
        )
        self.add_argument(
            "--phase",
            type=int,
            default=1,
            help="Learning phase 0 for single task, 1 for multi task",
        )
        self.add_argument(
            "--test",
            action="store_true",
            help="Whether to test the model",
        )
        self.add_argument(
            "--tag",
            type=str,
            default="v0",
            help="Tag of model",
        )
        # NOTE(review): no help text; the meaning of this flag is not
        # visible in this file.
        self.add_argument(
            "--context",
            action="store_true",
        )
